# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
import math
import json
import copy
from typing import List, Dict
import numpy as np
import torch
if torch.__version__ >= '1.8':
    import torch_npu
from torch import nn
from torch.nn import functional as F
import datetime
import pdb

from detectron2.modeling.proposal_generator.build import PROPOSAL_GENERATOR_REGISTRY
from detectron2.layers import ShapeSpec, cat
from detectron2.structures import Instances, Boxes
from detectron2.modeling import detector_postprocess
from detectron2.utils.comm import get_world_size
from torchvision.ops import nms 

from ..layers.heatmap_focal_loss import heatmap_focal_loss_jit
from ..layers.heatmap_focal_loss import binary_heatmap_focal_loss_jit
from ..layers.iou_loss import IOULoss
from ..layers.ml_nms import ml_nms
from ..debug import debug_train, debug_test
from .utils import reduce_sum, _transpose
from .centernet_head import CenterNetHead

__all__ = ["CenterNet"]

INF = 100000000

@PROPOSAL_GENERATOR_REGISTRY.register()
class CenterNet(nn.Module):
    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        """Copy the CenterNet options out of ``cfg`` and build the head.

        Args:
            cfg: config node. The ``MODEL.CENTERNET`` sub-node, ``DEBUG`` and
                ``VIS_THRESH`` are always read; ``MODEL.PIXEL_MEAN``,
                ``MODEL.PIXEL_STD`` and ``MODEL.DEVICE`` only in debug mode.
            input_shape: maps a feature name to its ``ShapeSpec``; only the
                entries listed in ``MODEL.CENTERNET.IN_FEATURES`` are used.
        """
        super().__init__()
        opts = cfg.MODEL.CENTERNET

        # feature selection and scoring
        self.num_classes = opts.NUM_CLASSES
        self.in_features = opts.IN_FEATURES
        self.strides = opts.FPN_STRIDES
        self.score_thresh = opts.INFERENCE_TH
        self.min_radius = opts.MIN_RADIUS

        # loss hyper-parameters
        self.hm_focal_alpha = opts.HM_FOCAL_ALPHA
        self.hm_focal_beta = opts.HM_FOCAL_BETA
        self.loss_gamma = opts.LOSS_GAMMA
        self.reg_weight = opts.REG_WEIGHT
        self.not_norm_reg = opts.NOT_NORM_REG

        # operating mode switches
        self.with_agn_hm = opts.WITH_AGN_HM
        self.only_proposal = opts.ONLY_PROPOSAL
        self.as_proposal = opts.AS_PROPOSAL
        self.not_nms = opts.NOT_NMS
        self.pos_weight = opts.POS_WEIGHT
        self.neg_weight = opts.NEG_WEIGHT
        self.sigmoid_clamp = opts.SIGMOID_CLAMP
        self.ignore_high_fp = opts.IGNORE_HIGH_FP
        self.center_nms = opts.CENTER_NMS
        self.sizes_of_interest = opts.SOI
        self.more_pos = opts.MORE_POS
        self.more_pos_thresh = opts.MORE_POS_THRESH
        self.more_pos_topk = opts.MORE_POS_TOPK

        # candidate budgets and NMS thresholds, train vs. test
        self.pre_nms_topk_train = opts.PRE_NMS_TOPK_TRAIN
        self.pre_nms_topk_test = opts.PRE_NMS_TOPK_TEST
        self.post_nms_topk_train = opts.POST_NMS_TOPK_TRAIN
        self.post_nms_topk_test = opts.POST_NMS_TOPK_TEST
        self.nms_thresh_train = opts.NMS_TH_TRAIN
        self.nms_thresh_test = opts.NMS_TH_TEST

        self.debug = cfg.DEBUG
        self.vis_thresh = cfg.VIS_THRESH

        # center (max-pool) NMS replaces box NMS
        if self.center_nms:
            self.not_nms = True

        self.iou_loss = IOULoss(opts.LOC_LOSS_TYPE)
        # proposal-only mode relies on the class-agnostic heatmap
        assert (not self.only_proposal) or self.with_agn_hm

        # delta for rendering heatmap
        min_overlap = opts.HM_MIN_OVERLAP
        self.delta = (1 - min_overlap) / (1 + min_overlap)

        head_input_shapes = [input_shape[name] for name in self.in_features]
        self.centernet_head = CenterNetHead(cfg, head_input_shapes)

        if self.debug:
            device = torch.device(cfg.MODEL.DEVICE)
            pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(device).view(3, 1, 1)
            pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(device).view(3, 1, 1)

            def _denormalize(x):
                # undo the (x - mean) / std input normalisation for display
                return x * pixel_std + pixel_mean

            self.denormalizer = _denormalize


    def forward(self, images, features_dict, gt_instances):
        """Run the head and return proposals plus (in training) the losses.

        In eval mode the result of ``self.inference`` is returned directly.
        In training mode the return value is ``(proposals, losses)``, where
        ``proposals`` stays ``None`` unless ``only_proposal`` or
        ``as_proposal`` is enabled.
        """
        features = [features_dict[name] for name in self.in_features]
        clss_per_level, reg_pred_per_level, agn_hm_pred_per_level = \
            self.centernet_head(features)
        grids = self.compute_grids(features)
        # (H_l, W_l) of every regression map, as a tensor on the grid device
        shapes_per_level = grids[0].new_tensor(
            [(reg.shape[2], reg.shape[3]) for reg in reg_pred_per_level])

        if not self.training:
            return self.inference(
                images, clss_per_level, reg_pred_per_level,
                agn_hm_pred_per_level, grids)

        pos_inds, labels, reg_targets, flattened_hms = self._get_ground_truth(
            grids, shapes_per_level, gt_instances)

        # logits_pred: M x F, reg_pred: M x 4, agn_hm_pred: M
        logits_pred, reg_pred, agn_hm_pred = self._flatten_outputs(
            clss_per_level, reg_pred_per_level, agn_hm_pred_per_level)

        if self.more_pos:
            # treat additional pixels as positive when
            #   1. they lie in the center 3x3 region of an object, and
            #   2. their regression loss is small (< self.more_pos_thresh)
            pos_inds, labels = self._add_more_pos(
                reg_pred, gt_instances, shapes_per_level)

        losses = self.losses(
            pos_inds, labels, reg_targets, flattened_hms,
            logits_pred, reg_pred, agn_hm_pred)

        proposals = None
        if self.only_proposal:
            # class-agnostic heatmap alone drives the proposals
            agn_hm_pred_per_level = [
                hm.sigmoid() for hm in agn_hm_pred_per_level]
            proposals = self.predict_instances(
                grids, agn_hm_pred_per_level, reg_pred_per_level,
                images.image_sizes, [None] * len(agn_hm_pred_per_level))
        elif self.as_proposal:
            # category specific boxes reused as agnostic proposals
            clss_per_level = [hm.sigmoid() for hm in clss_per_level]
            proposals = self.predict_instances(
                grids, clss_per_level, reg_pred_per_level,
                images.image_sizes, agn_hm_pred_per_level)

        if self.only_proposal or self.as_proposal:
            # rename the prediction fields to the proposal field names
            for proposal in proposals:
                proposal.proposal_boxes = proposal.get('pred_boxes')
                proposal.objectness_logits = proposal.get('scores')
                proposal.remove('pred_boxes')
                proposal.remove('scores')
                proposal.remove('pred_classes')

        if self.debug:
            debug_train(
                [self.denormalizer(x) for x in images],
                gt_instances, flattened_hms, reg_targets,
                labels, pos_inds, shapes_per_level, grids, self.strides)
        return proposals, losses


    def losses(
        self, pos_inds, labels, reg_targets, flattened_hms,
        logits_pred, reg_pred, agn_hm_pred):
        '''
        Compute the heatmap, localisation and agnostic-heatmap losses.

        Inputs:
            pos_inds: N
            labels: N
            reg_targets: M x 4
            flattened_hms: M x C
            logits_pred: M x C
            reg_pred: M x 4
            agn_hm_pred: M x 1 or None
            N: number of positive locations in all images
            M: number of pixels from all FPN levels
            C: number of classes
        Returns:
            dict mapping loss name to loss value
        '''
        assert (torch.isfinite(reg_pred).all().item())

        # average number of positives per process, never below 1
        local_pos_count = pos_inds.numel()
        world_size = get_world_size()
        total_num_pos = reduce_sum(
            pos_inds.new_tensor([local_pos_count]).int()).item()
        num_pos_avg = max(total_num_pos / world_size, 1.0)

        loss_dict = {}
        if not self.only_proposal:
            cls_pos, cls_neg = heatmap_focal_loss_jit(
                logits_pred, flattened_hms, pos_inds, labels,
                alpha=self.hm_focal_alpha,
                beta=self.hm_focal_beta,
                gamma=self.loss_gamma,
                reduction='sum',
                sigmoid_clamp=self.sigmoid_clamp,
                ignore_high_fp=self.ignore_high_fp,
            )
            loss_dict['loss_centernet_pos'] = \
                self.pos_weight * cls_pos / num_pos_avg
            loss_dict['loss_centernet_neg'] = \
                self.neg_weight * cls_neg / num_pos_avg

        # Fixed-shape regression loss: rows without a target (max < 0) are
        # multiplied by zero instead of being indexed away, so tensor shapes
        # do not depend on the number of positives.
        valid = (reg_targets.max(dim=1)[0] >= 0).unsqueeze(1).float()  # M x 1
        masked_pred = reg_pred * valid
        masked_targets = reg_targets * valid
        reg_norm = max(reduce_sum(valid.int().sum()).item() / world_size, 1)
        per_pixel_weight = valid.squeeze(1)  # M

        loc_loss = self.iou_loss(
            masked_pred, masked_targets, per_pixel_weight,
            reduction='sum')
        loss_dict['loss_centernet_loc'] = self.reg_weight * loc_loss / reg_norm

        if self.with_agn_hm:
            agn_target = flattened_hms.max(dim=1)[0]  # M
            # num_pos_avg is handed to the loss function here, so no
            # division by it is applied to the returned values
            agn_pos, agn_neg = binary_heatmap_focal_loss_jit(
                agn_hm_pred, agn_target, pos_inds, num_pos_avg,
                alpha=self.hm_focal_alpha,
                beta=self.hm_focal_beta,
                gamma=self.loss_gamma,
                sigmoid_clamp=self.sigmoid_clamp,
                ignore_high_fp=self.ignore_high_fp,
            )
            loss_dict['loss_centernet_agn_pos'] = self.pos_weight * agn_pos
            loss_dict['loss_centernet_agn_neg'] = self.neg_weight * agn_neg

        if self.debug:
            print('losses', loss_dict)
            print('total_num_pos', total_num_pos)
        return loss_dict


    def compute_grids(self, features):
        grids = []
        for level, feature in enumerate(features):
            h, w = feature.size()[-2:]
            shifts_x = torch.arange(
                0, w * self.strides[level], 
                step=self.strides[level],
                dtype=torch.float32, device=feature.device)
            shifts_y = torch.arange(
                0, h * self.strides[level], 
                step=self.strides[level],
                dtype=torch.float32, device=feature.device)
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            grids_per_level = torch.stack((shift_x, shift_y), dim=1) + \
                self.strides[level] // 2
            grids.append(grids_per_level)
        return grids


    def _get_ground_truth(self, grids, shapes_per_level, gt_instances):
        '''
        Build the dense regression targets and heatmap targets for a batch.

        Input:
            grids: list of tensors [(hl x wl, 2)]_l
            shapes_per_level: L x 2 tensor of (h_l, w_l)
            gt_instances: gt instances, one entry per image; `gt_boxes`
                and `gt_classes` are read from each
        Return:
            pos_inds: N (None when self.more_pos is set)
            labels: N (None when self.more_pos is set)
            reg_targets: M x 4, divided by the stride of the level;
                rows without a target are negative
            flattened_hms: M x C or M x 1
            N: number of objects in all images
            M: number of pixels from all FPN levels
        '''

        # get positive pixel index
        # (skipped when self.more_pos is set; the caller computes them later)
        if not self.more_pos:
            pos_inds, labels = self._get_label_inds(
                gt_instances, shapes_per_level) 
        else:
            pos_inds, labels = None, None
        heatmap_channels = self.num_classes
        L = len(grids)
        num_loc_list = [len(loc) for loc in grids]
        # stride of the level each pixel belongs to
        strides = torch.cat([
            shapes_per_level.new_ones(num_loc_list[l]) * self.strides[l] \
            for l in range(L)]).float() # M
        # (min, max) object size handled by the level each pixel belongs to
        reg_size_ranges = torch.cat([
            shapes_per_level.new_tensor(self.sizes_of_interest[l]).float().view(
            1, 2).expand(num_loc_list[l], 2) for l in range(L)]) # M x 2
        grids = torch.cat(grids, dim=0) # M x 2
        M = grids.shape[0]
        reg_targets = []
        flattened_hms = []
        for i in range(len(gt_instances)): # images
            boxes = gt_instances[i].gt_boxes.tensor # N x 4
            area = gt_instances[i].gt_boxes.area() # N
            gt_classes = gt_instances[i].gt_classes # N in [0, self.num_classes]

            N = boxes.shape[0]
            if N == 0:
                # image without boxes: every pixel gets the -INF
                # "no target" marker and an all-zero heatmap
                reg_targets.append(grids.new_zeros((M, 4)) - INF)
                flattened_hms.append(
                    grids.new_zeros((
                        M, 1 if self.only_proposal else heatmap_channels)))
                continue
            
            # signed distances from every grid point to the four box sides
            l = grids[:, 0].view(M, 1) - boxes[:, 0].view(1, N) # M x N
            t = grids[:, 1].view(M, 1) - boxes[:, 1].view(1, N) # M x N
            r = boxes[:, 2].view(1, N) - grids[:, 0].view(M, 1) # M x N
            b = boxes[:, 3].view(1, N) - grids[:, 1].view(M, 1) # M x N
            reg_target = torch.stack([l, t, r, b], dim=2) # M x N x 4

            centers = ((boxes[:, [0, 1]] + boxes[:, [2, 3]]) / 2) # N x 2
            centers_expanded = centers.view(1, N, 2).expand(M, N, 2) # M x N x 2
            strides_expanded = strides.view(M, 1, 1).expand(M, N, 2)
            # box centers snapped to the center of the stride cell they
            # fall into, computed with each pixel's own level stride
            centers_discret = ((centers_expanded / strides_expanded).int() * \
                strides_expanded).float() + strides_expanded / 2 # M x N x 2
            
            # a pixel is a peak of an object when it coincides with the
            # snapped center of that object
            is_peak = (((grids.view(M, 1, 2).expand(M, N, 2) - \
                centers_discret) ** 2).sum(dim=2) == 0) # M x N
            # strictly inside the box: all four side distances positive
            is_in_boxes = reg_target.min(dim=2)[0] > 0 # M x N
            is_center3x3 = self.get_center3x3(
                grids, centers, strides) & is_in_boxes # M x N
            is_cared_in_the_level = self.assign_reg_fpn(
                reg_target, reg_size_ranges) # M x N
            reg_mask = is_center3x3 & is_cared_in_the_level # M x N

            # squared distance to the real (unsnapped) center, forced to
            # exactly 0 at peaks so that exp(-dist) is 1 there
            dist2 = ((grids.view(M, 1, 2).expand(M, N, 2) - \
                centers_expanded) ** 2).sum(dim=2) # M x N
            dist2[is_peak] = 0
            radius2 = self.delta ** 2 * 2 * area # N
            radius2 = torch.clamp(
                radius2, min=self.min_radius ** 2)
            weighted_dist2 = dist2 / radius2.view(1, N).expand(M, N) # M x N            
            # clone: _get_reg_targets overwrites entries of its dist argument
            reg_target = self._get_reg_targets(
                reg_target, weighted_dist2.clone(), reg_mask, area) # M x 4

            if self.only_proposal:
                flattened_hm = self._create_agn_heatmaps_from_dist(
                    weighted_dist2.clone()) # M x 1
            else:
                flattened_hm = self._create_heatmaps_from_dist(
                    weighted_dist2.clone(), gt_classes, 
                    channels=heatmap_channels) # M x C

            reg_targets.append(reg_target)
            flattened_hms.append(flattened_hm)
        

        # transpose im first training_targets to level first ones
        # NOTE(review): _transpose is imported from .utils and not shown
        # here; the loop below assumes it returns one entry per level
        reg_targets = _transpose(reg_targets, num_loc_list)
        flattened_hms = _transpose(flattened_hms, num_loc_list)
        # express the regression targets in units of the level stride
        for l in range(len(reg_targets)):
            reg_targets[l] = reg_targets[l] / float(self.strides[l])
        reg_targets = cat([x for x in reg_targets], dim=0) # MB x 4
        flattened_hms = cat([x for x in flattened_hms], dim=0) # MB x C
        
        return pos_inds, labels, reg_targets, flattened_hms


    def _get_label_inds(self, gt_instances, shapes_per_level):
        '''
        Inputs:
            gt_instances: [n_i], sum n_i = N
            shapes_per_level: L x 2 [(h_l, w_l)]_L
        Returns:
            pos_inds: N'
            labels: N'
        '''
        # import pdb; pdb.set_trace()
        pos_inds = []
        labels = []
        L = len(self.strides)
        B = len(gt_instances)
        shapes_per_level = shapes_per_level.long()
        loc_per_level = (shapes_per_level[:, 0] * shapes_per_level[:, 1]).long() # L
        level_bases = []
        s = 0
        for l in range(L):
            level_bases.append(s)
            s = s + B * loc_per_level[l]
        level_bases = shapes_per_level.new_tensor(level_bases).long() # L
        strides_default = shapes_per_level.new_tensor(self.strides).float() # L
        for im_i in range(B):
            targets_per_im = gt_instances[im_i]
            bboxes = targets_per_im.gt_boxes.tensor # n x 4
            n = bboxes.shape[0]
            centers = ((bboxes[:, [0, 1]] + bboxes[:, [2, 3]]) / 2) # n x 2
            centers = centers.view(n, 1, 2).expand(n, L, 2)
            strides = strides_default.view(1, L, 1).expand(n, L, 2)
            centers_inds = (centers / strides).long() # n x L x 2
            Ws = shapes_per_level[:, 1].view(1, L).expand(n, L)
            pos_ind = level_bases.view(1, L).expand(n, L) + \
                       im_i * loc_per_level.view(1, L).expand(n, L) + \
                       centers_inds[:, :, 1] * Ws + \
                       centers_inds[:, :, 0] # n x L
            is_cared_in_the_level = self.assign_fpn_level(bboxes)
            pos_ind = pos_ind[is_cared_in_the_level].view(-1)
            label = targets_per_im.gt_classes.view(
                n, 1).expand(n, L)[is_cared_in_the_level].view(-1)

            pos_inds.append(pos_ind) # n'
            labels.append(label) # n'
        pos_inds = torch.cat(pos_inds, dim=0).long()
        labels = torch.cat(labels, dim=0)
        return pos_inds, labels # N, N


    def assign_fpn_level(self, boxes):
        '''
        Inputs:
            boxes: n x 4
            size_ranges: L x 2
        Return:
            is_cared_in_the_level: n x L
        '''
        size_ranges = boxes.new_tensor(
            self.sizes_of_interest).view(len(self.sizes_of_interest), 2) # L x 2
        crit = ((boxes[:, 2:] - boxes[:, :2]) **2).sum(dim=1) ** 0.5 / 2 # n
        crit = torch.floor(crit)
        n, L = crit.shape[0], size_ranges.shape[0]
        crit = crit.view(n, 1).expand(n, L)
        size_ranges_expand = size_ranges.view(1, L, 2).expand(n, L, 2)
        is_cared_in_the_level = (crit >= size_ranges_expand[:, :, 0]) & \
            (crit <= size_ranges_expand[:, :, 1])
        return is_cared_in_the_level
    

    def assign_reg_fpn(self, reg_targets_per_im, size_ranges):
        '''
        TODO (Xingyi): merge it with assign_fpn_level
        Inputs:
            reg_targets_per_im: M x N x 4
            size_ranges: M x 2
        '''
        crit = ((reg_targets_per_im[:, :, :2] + \
            reg_targets_per_im[:, :, 2:])**2).sum(dim=2) ** 0.5 / 2 # M x N
        is_cared_in_the_level = (crit >= size_ranges[:, [0]]) & \
            (crit <= size_ranges[:, [1]])
        return is_cared_in_the_level


    def _get_reg_targets(self, reg_targets, dist, mask, area):
        '''
        Pick, for every pixel, the target of the closest eligible object.

          reg_targets (M x N x 4): long tensor
          dist (M x N): modified in place
          mask (M x N): non-zero where the pixel may regress the object
          area: not used

        Returns M x 4; pixels with no eligible object are filled with -INF.
        '''
        # ineligible pairs can never win the min below
        dist[mask == 0] = INF * 1.0
        closest_dist, closest_obj = dist.min(dim=1)  # M, M
        pixel_ids = range(len(reg_targets))
        per_pixel = reg_targets[pixel_ids, closest_obj]  # M x N x 4 --> M x 4
        unassigned = closest_dist == INF
        per_pixel[unassigned] = - INF
        return per_pixel


    def _create_heatmaps_from_dist(self, dist, labels, channels):
        '''
        dist: M x N
        labels: N
        return:
          heatmaps: M x C
        '''
        heatmaps = dist.new_zeros((dist.shape[0], channels))
        for c in range(channels):
            inds = (labels == c) # N
            if inds.int().sum() == 0:
                continue
            heatmaps[:, c] = torch.exp(-dist[:, inds].min(dim=1)[0])
            zeros = heatmaps[:, c] < 1e-4
            heatmaps[zeros, c] = 0
        return heatmaps


    def _create_agn_heatmaps_from_dist(self, dist):
        '''
        TODO (Xingyi): merge it with _create_heatmaps_from_dist
        dist: M x N
        return:
          heatmaps: M x 1
        '''
        heatmaps = dist.new_zeros((dist.shape[0], 1))
        heatmaps[:, 0] = torch.exp(-dist.min(dim=1)[0])
        zeros = heatmaps < 1e-4
        heatmaps[zeros] = 0
        return heatmaps


    def _flatten_outputs(self, clss, reg_pred, agn_hm_pred):
        """Flatten the per-level head outputs into one tensor each.

        Each (N, F, Hl, Wl) map becomes (N * Hl * Wl, F) rows and the levels
        are concatenated along dim 0. ``clss`` is None when the first class
        map is None; ``agn_hm_pred`` is None unless ``self.with_agn_hm``.
        """
        def _channels_last(x):
            # (N, F, Hl, Wl) -> (N, Hl, Wl, F)
            return x.permute(0, 2, 3, 1)

        if clss[0] is not None:
            flat_clss = cat(
                [_channels_last(x).reshape(-1, x.shape[1]) for x in clss],
                dim=0)
        else:
            flat_clss = None

        flat_reg = cat(
            [_channels_last(x).reshape(-1, 4) for x in reg_pred], dim=0)

        if self.with_agn_hm:
            flat_agn = cat(
                [_channels_last(x).reshape(-1) for x in agn_hm_pred], dim=0)
        else:
            flat_agn = None
        return flat_clss, flat_reg, flat_agn


    def get_center3x3(self, locations, centers, strides):
        '''
        Inputs:
            locations: M x 2
            centers: N x 2
            strides: M
        '''
        M, N = locations.shape[0], centers.shape[0]
        locations_expanded = locations.view(M, 1, 2).expand(M, N, 2) # M x N x 2
        centers_expanded = centers.view(1, N, 2).expand(M, N, 2) # M x N x 2
        strides_expanded = strides.view(M, 1, 1).expand(M, N, 2) # M x N
        centers_discret = ((centers_expanded / strides_expanded).int() * \
            strides_expanded).float() + strides_expanded / 2 # M x N x 2
        dist_x = (locations_expanded[:, :, 0] - centers_discret[:, :, 0]).abs()
        dist_y = (locations_expanded[:, :, 1] - centers_discret[:, :, 1]).abs()
        return (dist_x <= strides_expanded[:, :, 0]) & \
            (dist_y <= strides_expanded[:, :, 0])


    def inference(self, images, clss_per_level, reg_pred_per_level, 
        agn_hm_pred_per_level, grids):
        """Turn the raw head outputs into per-image predictions.

        Returns ``(proposals, {})``. In proposal mode (``as_proposal`` or
        ``only_proposal``) the boxes are exposed as ``proposal_boxes`` and
        the scores are additionally exposed as ``objectness_logits``.
        """
        def _sigmoid_or_none(x):
            return None if x is None else x.sigmoid()

        logits_pred = [_sigmoid_or_none(x) for x in clss_per_level]
        agn_hm_pred_per_level = [
            _sigmoid_or_none(x) for x in agn_hm_pred_per_level]

        if self.only_proposal:
            # the agnostic heatmap itself is scored; nothing to weight by
            heatmaps = agn_hm_pred_per_level
            agn_weights = [None] * len(agn_hm_pred_per_level)
        else:
            heatmaps = logits_pred
            agn_weights = agn_hm_pred_per_level
        proposals = self.predict_instances(
            grids, heatmaps, reg_pred_per_level,
            images.image_sizes, agn_weights)

        if self.as_proposal or self.only_proposal:
            for proposal in proposals:
                proposal.proposal_boxes = proposal.get('pred_boxes')
                proposal.objectness_logits = proposal.get('scores')
                proposal.remove('pred_boxes')

        if self.debug:
            debug_test(
                [self.denormalizer(x) for x in images],
                logits_pred, reg_pred_per_level,
                agn_hm_pred_per_level, preds=proposals,
                vis_thresh=self.vis_thresh,
                debug_show_name=False)
        return proposals, {}


    @torch.no_grad()
    def gen_seperate_ind(self, boxes):
        rb_y = boxes[:, 3]
        _, indices = torch.topk(rb_y, k=rb_y.shape[0])
        ind_b0, ind_b1 = indices[:2048], indices[2048:]
        return ind_b0, ind_b1, indices

    def batched_nms_npu(self, boxes, scores, idxs, iou_threshold):
        """Run NMS on the CPU and return a boolean keep mask.

        Stand-in that mimics the output of the NPU ``nms_with_mask``
        operator; it is meant to be replaced by that operator once this
        part is migrated to the NPU.

        Args:
            boxes: n x 4 boxes.
            scores: n scores.
            idxs: per-box category ids. Not used: suppression is applied
                across all boxes regardless of category.
            iou_threshold: IoU threshold handed to ``nms``.
        Returns:
            Bool tensor with the same shape as ``scores``, True for kept
            boxes, placed on the device ``scores`` came from.
        """
        # Remember where the inputs live so the mask can be combined with
        # them afterwards. The previous hard-coded `.npu()` failed on any
        # other device and whenever torch_npu had not been imported.
        target_device = scores.device
        keep = nms(boxes.cpu(), scores.cpu(), iou_threshold)
        result_mask = torch.zeros(scores.size(), dtype=torch.bool)
        result_mask[keep] = True
        return result_mask.to(target_device)

    def nms_and_topK_v2(self, boxlists, nms=True):
        """Fixed-shape NMS followed by top-k selection, per image.

        Suppressed boxes are not removed. Their score is zeroed for ranking
        only, and the ``post_nms_topk`` best-ranked candidates are returned.
        The returned instances keep their original ``scores`` field, and a
        suppressed box can still be selected when fewer than
        ``post_nms_topk`` candidates survive.

        Args:
            boxlists: one entry per image, each a sequence with one
                ``Instances`` per FPN level.
            nms: when False no suppression is applied and every candidate
                takes part in the top-k ranking.
        Returns:
            list with one ``Instances`` per image.
        """
        nms_thresh = self.nms_thresh_train if self.training else \
            self.nms_thresh_test
        post_nms_topk = self.post_nms_topk_train if self.training else \
            self.post_nms_topk_test

        results = []
        for boxes_per_level in boxlists:  # images
            level_keep_list = []
            level_scores_per_img = []
            for box in boxes_per_level:  # FPN levels
                scores = box.scores
                if nms:
                    keep_mask = self._nms_keep_mask(box, nms_thresh)
                else:
                    # Previously nothing was collected in this case and the
                    # concatenation below received an empty list.
                    keep_mask = torch.ones_like(scores, dtype=torch.bool)
                level_scores_per_img.append(scores)
                level_keep_list.append(keep_mask)

            keep_mask = cat(level_keep_list, dim=0)
            scores_per_img = cat(level_scores_per_img, dim=0)
            scores_per_img = scores_per_img * keep_mask.float()
            # topk raises when k exceeds the number of candidates
            k = min(post_nms_topk, scores_per_img.shape[0])
            _, indice = torch.topk(scores_per_img, k)
            boxeslist_per_img = Instances.cat(boxes_per_level)
            results.append(boxeslist_per_img[indice])
        return results

    def _nms_keep_mask(self, box, nms_thresh):
        """Return the NMS keep mask for the candidates of one FPN level."""
        if box.has('pred_boxes'):
            boxes = box.pred_boxes.tensor
            labels = box.pred_classes
        else:
            # NOTE(review): the field name `proposal` is kept from the
            # original code - confirm it matches what callers store.
            boxes = box.proposal.tensor
            labels = boxes.new_zeros(len(boxes))
        scores = box.scores

        # The NPU operator only accepts up to 2048 boxes per call.
        if boxes.shape[0] <= 2048:
            return self.batched_nms_npu(boxes, scores, labels, nms_thresh)

        # More than 2048 boxes: split once. This assumes the pre-NMS budget
        # keeps the second chunk within the operator limit as well.
        ind_b0, ind_b1, ind = self.gen_seperate_ind(boxes)
        keep_b0 = self.batched_nms_npu(
            boxes[ind_b0], scores[ind_b0], labels[ind_b0], nms_thresh)
        keep_b1 = self.batched_nms_npu(
            boxes[ind_b1], scores[ind_b1], labels[ind_b1], nms_thresh)

        # Scatter the chunk results back to the original box order.
        sorted_keep = torch.cat((keep_b0, keep_b1), 0)
        keep_mask = torch.empty_like(sorted_keep)
        keep_mask[ind] = sorted_keep
        return keep_mask


    def predict_instances(
        self, grids, logits_pred, reg_pred, image_sizes, agn_hm_pred, 
        is_proposal=False):
        """Decode every level, then merge the levels per image.

        The regression output of each level is multiplied by the level's
        stride before decoding. The per-level results are regrouped per
        image and passed to ``nms_and_topK_v2``.
        """
        per_level = [
            self.predict_single_level(
                grids[lvl], logits_pred[lvl],
                reg_pred[lvl] * self.strides[lvl],
                image_sizes, agn_hm_pred[lvl], lvl,
                is_proposal=is_proposal)
            for lvl in range(len(grids))
        ]
        # level-major -> image-major
        per_image = list(zip(*per_level))
        # fixed-shape variant of NMS + top-k
        return self.nms_and_topK_v2(per_image, nms=not self.not_nms)


    def predict_single_level(
        self, grids, heatmap, reg_pred, image_sizes, agn_hm, level, 
        is_proposal=False):
        N, C, H, W = heatmap.shape
        # if C!=1:
        #     print(C)
        # put in the same format as grids
        if self.center_nms:
            heatmap_nms = nn.functional.max_pool2d(
                heatmap, (3, 3), stride=1, padding=1)
            heatmap = heatmap * (heatmap_nms == heatmap).float()
        heatmap = heatmap.permute(0, 2, 3, 1) # N x H x W x C
        heatmap = heatmap.reshape(N, -1, C) # N x HW x C
        box_regression = reg_pred.view(N, 4, H, W).permute(0, 2, 3, 1) # N x H x W x 4 
        box_regression = box_regression.reshape(N, -1, 4) # N x HW x 4

        candidate_inds = heatmap > self.score_thresh # 0.05
        # pre_nms_top_n = torch.ones_like(candidate_inds).view(N, -1).sum(1) # N
        pre_nms_top_n = candidate_inds.view(N, -1).sum(1) # N
        pre_nms_topk = self.pre_nms_topk_train if self.training else self.pre_nms_topk_test # 2000 in train
        pre_nms_top_n = pre_nms_top_n.clamp(max=pre_nms_topk) # N

        if agn_hm is not None:
            agn_hm = agn_hm.view(N, 1, H, W).permute(0, 2, 3, 1)
            agn_hm = agn_hm.reshape(N, -1)
            heatmap = heatmap * agn_hm[:, :, None]

        results = []
        
        # C = 1
        for i in range(N):
            per_box_cls = heatmap[i] # HW x C

            #fixed shape start

            # per_candidate_inds = candidate_inds[i] # n
            # per_box_cls = per_box_cls[per_candidate_inds] # n

            # per_candidate_nonzeros = per_candidate_inds.nonzero() # n
            # per_box_loc = per_candidate_nonzeros[:, 0] # n
            # per_class = per_candidate_nonzeros[:, 1] # n

            # per_box_regression = box_regression[i] # HW x 4
            # per_box_regression = per_box_regression[per_box_loc] # n x 4
            # per_grids = grids[per_box_loc] # n x 2

            per_candidate_inds = candidate_inds[i].squeeze() # n 
            # per_box_cls = torch.where(per_candidate_inds == 0.0, torch.zeros_like(per_candidate_inds), per_box_cls)
            per_box_cls = per_box_cls.squeeze() * per_candidate_inds
            # per_candidate_inds = per_candidate_inds1.squeeze()
            per_class = torch.zeros_like(per_candidate_inds)
            per_box_regression = box_regression[i] # HW x 4
            # per_box_regression = torch.where(per_candidate_inds.repeat(1,4) == 0.0, torch.zeros_like(box_regression), per_box_regression)
            # per_grids = torch.where(per_candidate_inds.repeat(1,2) == 0.0, torch.zeros_like(grids), grids)
            
            per_box_regression = per_box_regression * per_candidate_inds.unsqueeze(1)
            per_grids = grids * per_candidate_inds.unsqueeze(1)
            
            
            #fixed shape end

            # per_pre_nms_top_n = pre_nms_top_n[i] # 1

            # if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
            if H * W > pre_nms_topk:
                per_box_cls, top_k_indices = \
                    per_box_cls.topk(pre_nms_topk, sorted=False)
                # import pdb; pdb.set_trace()
                per_class = per_class[top_k_indices]
                per_box_regression = per_box_regression[top_k_indices]
                per_grids = per_grids[top_k_indices]
            
            detections = torch.stack([
                per_grids[:, 0] - per_box_regression[:, 0],
                per_grids[:, 1] - per_box_regression[:, 1],
                per_grids[:, 0] + per_box_regression[:, 2],
                per_grids[:, 1] + per_box_regression[:, 3],
            ], dim=1) # n x 4

            # avoid invalid boxes in RoI heads
            detections[:, 2] = torch.max(detections[:, 2], detections[:, 0] + 0.01)
            detections[:, 3] = torch.max(detections[:, 3], detections[:, 1] + 0.01)
            boxlist = Instances(image_sizes[i])
            # boxlist = Instances([1344,1344]) #fixed shape
            boxlist.scores = torch.sqrt(per_box_cls) \
                if self.with_agn_hm else per_box_cls # n
            # import pdb; pdb.set_trace()
            boxlist.pred_boxes = Boxes(detections)
            boxlist.pred_classes = per_class
            results.append(boxlist)
        return results


    def nms_and_topK(self, boxlists, nms=True):
        '''
        Run per-image NMS and keep at most ``post_nms_topk`` proposals.

        Inputs:
            boxlists: one proposal container per image. Each must support
                ``len()``, expose a 1-D ``scores`` tensor and accept a tensor
                of indices in ``[]`` (presumably detectron2 ``Instances`` --
                confirm against the caller).
            nms: when True, ``ml_nms`` is applied first with the train/test
                NMS threshold; when False only the top-k filter runs.
        Returns:
            A list with one filtered container per input image, in input
            order. Surviving proposals keep their relative order.
        '''
        # Both settings depend only on the train/eval mode, so resolve them
        # once instead of once per image.
        if self.training:
            nms_thresh = self.nms_thresh_train
            post_nms_topk = self.post_nms_topk_train
        else:
            nms_thresh = self.nms_thresh_test
            post_nms_topk = self.post_nms_topk_test

        results = []
        for boxlist in boxlists:
            result = ml_nms(boxlist, nms_thresh) if nms else boxlist
            if self.debug:
                print('#proposals before nms', len(boxlist))
                print('#proposals after nms', len(result))
            if len(result) > post_nms_topk:
                cls_scores = result.scores
                # Select exactly post_nms_topk entries. Thresholding at the
                # k-th score with ">=" keeps every proposal tied with that
                # score (e.g. many scores that are exactly 0), which lets the
                # output grow past the cap.
                _, top_inds = torch.topk(cls_scores, post_nms_topk)
                # Go through a boolean mask so the kept proposals stay in
                # their original order rather than in score order.
                keep_mask = torch.zeros_like(cls_scores, dtype=torch.bool)
                keep_mask[top_inds] = True
                keep = torch.nonzero(keep_mask).squeeze(1)
                result = result[keep]
            if self.debug:
                print('#proposals after filter', len(result))
            results.append(result)
        return results


    def _add_more_pos(self, reg_pred, gt_instances, shapes_per_level):
        '''
        Pick extra positive locations among the 3x3 cells around every
        ground-truth center, based on the loss returned by self.iou_loss.

        A cell is selected when its loss is strictly below a per-object
        threshold: the more_pos_topk-th smallest loss of that object, capped
        at more_pos_thresh.

        Inputs:
            reg_pred: regression predictions, gathered along dim 0 with the
                flat indices from _get_c33_inds (assumed M x 4, as implied by
                the views below -- confirm against the caller)
            gt_instances: [n_i], sum n_i = N
            shapes_per_level: L x 2 [(h_l, w_l)]_L
        Returns:
            pos_inds: P, flat indices of the selected cells
            labels: P, label of the object each selected cell belongs to
        '''
        labels, level_masks, c33_inds, c33_masks, c33_regs = \
            self._get_c33_inds(gt_instances, shapes_per_level)
        num_gt = labels.shape[0]
        num_levels = len(self.strides)
        num_nbrs = 9
        num_cells = num_gt * num_levels * num_nbrs

        # Invalid cells are pointed at index 0 so the gather stays in range;
        # their loss is overwritten with INF below, so the value is unused.
        invalid_reg = c33_masks == 0
        c33_inds[invalid_reg] = 0
        reg_pred_c33 = reg_pred[c33_inds].detach() # N x L x K x 4 (if reg_pred is M x 4)
        reg_targets = c33_regs.view(num_cells, 4).clamp(min=0)

        if num_gt == 0:
            c33_reg_loss = reg_pred_c33.new_zeros(
                (num_gt, num_levels, num_nbrs)).detach()
        else:
            with torch.no_grad():
                c33_reg_loss = self.iou_loss(
                    reg_pred_c33.view(num_cells, 4),
                    reg_targets, None,
                    reduction='none')
                c33_reg_loss = c33_reg_loss.view(
                    num_gt, num_levels, num_nbrs).detach() # N x L x K

        c33_reg_loss[invalid_reg] = INF
        # Slot 4 of the 3x3 window is the center cell itself: give it zero
        # loss on every level the object is assigned to.
        c33_reg_loss.view(num_gt * num_levels, num_nbrs)[
            level_masks.view(num_gt * num_levels), 4] = 0
        c33_reg_loss = c33_reg_loss.view(num_gt, num_levels * num_nbrs)

        if num_gt == 0:
            loss_thresh = c33_reg_loss.new_ones((num_gt)).float()
        else:
            # k-th smallest loss among all L * K cells of each object
            loss_thresh = torch.kthvalue(
                c33_reg_loss, self.more_pos_topk, dim=1)[0] # N
        too_loose = loss_thresh > self.more_pos_thresh
        loss_thresh[too_loose] = self.more_pos_thresh # N

        new_pos = c33_reg_loss.view(num_gt, num_levels, num_nbrs) < \
            loss_thresh.view(num_gt, 1, 1) # N x L x K
        pos_inds = c33_inds[new_pos].view(-1) # P
        expanded_labels = labels.view(num_gt, 1, 1).expand(
            num_gt, num_levels, num_nbrs)
        labels = expanded_labels[new_pos].view(-1) # P
        return pos_inds, labels
        
    
    def _get_c33_inds(self, gt_instances, shapes_per_level):
        '''
        TODO (Xingyi): The current implementation is ugly. Refactor.
        Get the center (and the 3x3 region near center) locations of each objects
        Inputs:
            gt_instances: [n_i], sum n_i = N
            shapes_per_level: L x 2 [(h_l, w_l)]_L
        '''
        labels = []
        level_masks = []
        c33_inds = []
        c33_masks = []
        c33_regs = []
        L = len(self.strides)
        B = len(gt_instances)
        shapes_per_level = shapes_per_level.long()
        loc_per_level = (shapes_per_level[:, 0] * shapes_per_level[:, 1]).long() # L
        level_bases = []
        s = 0
        for l in range(L):
            level_bases.append(s)
            s = s + B * loc_per_level[l]
        level_bases = shapes_per_level.new_tensor(level_bases).long() # L
        strides_default = shapes_per_level.new_tensor(self.strides).float() # L
        K = 9
        dx = shapes_per_level.new_tensor([-1, 0, 1, -1, 0, 1, -1, 0, 1]).long()
        dy = shapes_per_level.new_tensor([-1, -1, -1, 0, 0, 0, 1, 1, 1]).long()
        for im_i in range(B):
            targets_per_im = gt_instances[im_i]
            bboxes = targets_per_im.gt_boxes.tensor # n x 4
            n = bboxes.shape[0]
            if n == 0:
                continue
            centers = ((bboxes[:, [0, 1]] + bboxes[:, [2, 3]]) / 2) # n x 2
            centers = centers.view(n, 1, 2).expand(n, L, 2)

            strides = strides_default.view(1, L, 1).expand(n, L, 2) # 
            centers_inds = (centers / strides).long() # n x L x 2
            center_grids = centers_inds * strides + strides // 2# n x L x 2
            l = center_grids[:, :, 0] - bboxes[:, 0].view(n, 1).expand(n, L)
            t = center_grids[:, :, 1] - bboxes[:, 1].view(n, 1).expand(n, L)
            r = bboxes[:, 2].view(n, 1).expand(n, L) - center_grids[:, :, 0]
            b = bboxes[:, 3].view(n, 1).expand(n, L) - center_grids[:, :, 1] # n x L
            reg = torch.stack([l, t, r, b], dim=2) # n x L x 4
            reg = reg / strides_default.view(1, L, 1).expand(n, L, 4).float()
            
            Ws = shapes_per_level[:, 1].view(1, L).expand(n, L)
            Hs = shapes_per_level[:, 0].view(1, L).expand(n, L)
            expand_Ws = Ws.view(n, L, 1).expand(n, L, K)
            expand_Hs = Hs.view(n, L, 1).expand(n, L, K)
            label = targets_per_im.gt_classes.view(n).clone()
            mask = reg.min(dim=2)[0] >= 0 # n x L
            mask = mask & self.assign_fpn_level(bboxes)
            labels.append(label) # n
            level_masks.append(mask) # n x L

            Dy = dy.view(1, 1, K).expand(n, L, K)
            Dx = dx.view(1, 1, K).expand(n, L, K)
            c33_ind = level_bases.view(1, L, 1).expand(n, L, K) + \
                       im_i * loc_per_level.view(1, L, 1).expand(n, L, K) + \
                       (centers_inds[:, :, 1:2].expand(n, L, K) + Dy) * expand_Ws + \
                       (centers_inds[:, :, 0:1].expand(n, L, K) + Dx) # n x L x K
            
            c33_mask = \
                ((centers_inds[:, :, 1:2].expand(n, L, K) + dy) < expand_Hs) & \
                ((centers_inds[:, :, 1:2].expand(n, L, K) + dy) >= 0) & \
                ((centers_inds[:, :, 0:1].expand(n, L, K) + dx) < expand_Ws) & \
                ((centers_inds[:, :, 0:1].expand(n, L, K) + dx) >= 0)
            # TODO (Xingyi): think about better way to implement this
            # Currently it hard codes the 3x3 region
            c33_reg = reg.view(n, L, 1, 4).expand(n, L, K, 4).clone()
            c33_reg[:, :, [0, 3, 6], 0] -= 1
            c33_reg[:, :, [0, 3, 6], 2] += 1
            c33_reg[:, :, [2, 5, 8], 0] += 1
            c33_reg[:, :, [2, 5, 8], 2] -= 1
            c33_reg[:, :, [0, 1, 2], 1] -= 1
            c33_reg[:, :, [0, 1, 2], 3] += 1
            c33_reg[:, :, [6, 7, 8], 1] += 1
            c33_reg[:, :, [6, 7, 8], 3] -= 1
            c33_mask = c33_mask & (c33_reg.min(dim=3)[0] >= 0) # n x L x K
            c33_inds.append(c33_ind)
            c33_masks.append(c33_mask)
            c33_regs.append(c33_reg)
        
        if len(level_masks) > 0:
            labels = torch.cat(labels, dim=0)
            level_masks = torch.cat(level_masks, dim=0)
            c33_inds = torch.cat(c33_inds, dim=0).long()
            c33_regs = torch.cat(c33_regs, dim=0)
            c33_masks = torch.cat(c33_masks, dim=0)
        else:
            labels = shapes_per_level.new_zeros((0)).long()
            level_masks = shapes_per_level.new_zeros((0, L)).bool()
            c33_inds = shapes_per_level.new_zeros((0, L, K)).long()
            c33_regs = shapes_per_level.new_zeros((0, L, K, 4)).float()
            c33_masks = shapes_per_level.new_zeros((0, L, K)).bool()
        return labels, level_masks, c33_inds, c33_masks, c33_regs # N x L, N x L x K
